import torch
from torch.nn import functional as F

def get_contrastive_loss(args, features1, features2):
    """Return the symmetric (two-direction) contrastive loss between two feature sets.

    Row ``i`` of ``features1`` and row ``i`` of ``features2`` form the only positive
    pair; every other row of the opposite set acts as a negative. The loss is the
    sum of the features1->features2 and features2->features1 cross-entropy terms
    over temperature-scaled cosine-similarity logits.

    Args:
        args: config object; only ``args.cl_temp`` (softmax temperature) is read.
            The batch size and device are taken from the features themselves, so
            ``args.num_samples`` / ``args.device`` need not match the batch (e.g. a
            smaller final batch works).
        features1: 2-D tensor, presumably (N, D) embeddings of the first view.
        features2: 2-D tensor with the same number of rows N as ``features1``.

    Returns:
        Scalar tensor ``CE(S / T, arange(N)) + CE(S.T / T, arange(N))`` where
        ``S[i, j] = cos(features1[i], features2[j])`` and ``T = args.cl_temp``.
        This equals the negated sum of the mean diagonal log-softmax in each
        direction.

    Raises:
        ValueError: if ``features1`` and ``features2`` have different batch sizes,
            since positives are defined by matching row indices.
    """
    n1, n2 = features1.size(0), features2.size(0)
    if n1 != n2:
        raise ValueError(
            f"features1 and features2 must have the same batch size, got {n1} and {n2}"
        )

    # logits[i, j] = cos(features1[i], features2[j]) / T  -> shape (N, N).
    logits = F.cosine_similarity(features1[:, None, :], features2[None, :, :], dim=-1)
    logits = logits / args.cl_temp

    # The positive for row i is column i, so the class targets are simply 0..N-1.
    # Build them on the logits' device to avoid a device mismatch.
    targets = torch.arange(logits.size(0), device=logits.device)

    # cross_entropy applies a log-softmax computed with log-sum-exp, so a small
    # temperature cannot overflow exp() the way exp(x).sum().log() does.
    loss_ij = F.cross_entropy(logits, targets)
    # Cosine similarity is symmetric, so the features2 -> features1 logits are
    # exactly the transpose; no need to recompute them.
    loss_ji = F.cross_entropy(logits.T, targets)

    contrastive_loss = loss_ij + loss_ji
    return contrastive_loss
